correlation shift


Conditional Mutual Information for Disentangled Representations in Reinforcement Learning

Neural Information Processing Systems

Reinforcement Learning (RL) environments can produce training data with spurious correlations between features due to a limited amount of training data or limited feature coverage. This can lead RL agents to encode these misleading correlations in their latent representation, preventing the agent from generalising if the correlation changes within the environment or when the agent is deployed in the real world. Disentangled representations can improve robustness, but existing disentanglement techniques that minimise mutual information between features require independent features, so they cannot disentangle correlated features. We propose an auxiliary task for RL algorithms that learns a disentangled representation of high-dimensional observations with correlated features by minimising the conditional mutual information between features in the representation. We demonstrate experimentally, using continuous control tasks, that our approach improves generalisation under correlation shifts, as well as improving the training performance of RL algorithms in the presence of correlated features.
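
The abstract does not say how the conditional mutual information is estimated for high-dimensional observations. As a minimal sketch, assuming jointly Gaussian scalar features, the Python below uses the closed form I(X;Y) = 0.5 log(1 / (1 - rho^2)) together with the partial correlation given a conditioning variable Z to show why the two objectives differ: two features driven by a shared factor Z have non-zero mutual information, so an MI-minimising objective penalises them, yet their conditional mutual information given Z is zero. The helper names and the toy model X = Z + e1, Y = Z + e2 are illustrative assumptions, not the paper's method.

```python
import math

def gaussian_mi(rho: float) -> float:
    """I(X;Y) in nats for jointly Gaussian scalars with correlation rho (|rho| < 1)."""
    return 0.5 * math.log(1.0 / (1.0 - rho ** 2))

def partial_corr(rho_xy: float, rho_xz: float, rho_yz: float) -> float:
    """Correlation between X and Y after conditioning on Z (jointly Gaussian case)."""
    return (rho_xy - rho_xz * rho_yz) / math.sqrt((1.0 - rho_xz ** 2) * (1.0 - rho_yz ** 2))

def gaussian_cmi(rho_xy: float, rho_xz: float, rho_yz: float) -> float:
    """I(X;Y|Z) in nats: the Gaussian MI formula applied to the partial correlation."""
    return gaussian_mi(partial_corr(rho_xy, rho_xz, rho_yz))

# Hypothetical toy model: X = Z + e1, Y = Z + e2, with Z, e1, e2 independent and
# each of unit variance, so Var(X) = Var(Y) = 2 and Cov(X,Y) = Cov(X,Z) = Cov(Y,Z) = 1.
rho_xy = 1.0 / math.sqrt(2.0 * 2.0)  # 0.5
rho_xz = 1.0 / math.sqrt(2.0 * 1.0)  # about 0.707
rho_yz = rho_xz

mi = gaussian_mi(rho_xy)
cmi = gaussian_cmi(rho_xy, rho_xz, rho_yz)
print(f"I(X;Y)   = {mi:.4f} nats")   # positive: the features are correlated
print(f"I(X;Y|Z) = {cmi:.4f} nats")  # about zero: independent given the shared factor
```

A representation that minimised plain mutual information would have to break this correlation, whereas conditioning on the shared factor leaves correlated but otherwise independent features unpenalised.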


Bayesian Cross-Modal Alignment Learning for Few-Shot Out-of-Distribution Generalization

arXiv.org Artificial Intelligence

Recent advances in large pre-trained models have shown promising results in few-shot learning. However, their generalization ability on two-dimensional Out-of-Distribution (OoD) data, i.e., correlation shift and diversity shift, has not been thoroughly investigated. Research has shown that, even with a significant amount of training data, few methods achieve better performance than the standard empirical risk minimization (ERM) method in OoD generalization. This few-shot OoD generalization dilemma is a challenging direction in deep neural network generalization research, where performance suffers from both overfitting on the few-shot examples and OoD generalization errors. In this paper, leveraging a broader supervision source, we explore a novel Bayesian cross-modal image-text alignment learning method (Bayes-CAL) to address this issue. Specifically, only the text representations are fine-tuned, via a Bayesian modelling approach with a gradient orthogonalization loss and an invariant risk minimization (IRM) loss. The Bayesian approach is introduced primarily to avoid overfitting to the base classes observed during training and to improve generalization to broader unseen classes. The dedicated loss is introduced to achieve better image-text alignment by disentangling the causal and non-causal parts of image features. Numerical experiments demonstrate that Bayes-CAL achieves state-of-the-art OoD generalization performance on two-dimensional distribution shifts. Moreover, compared with CLIP-like models, Bayes-CAL yields more stable generalization performance on unseen classes. Our code is available at https://github.com/LinLLLL/BayesCAL.
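
The abstract does not spell out the IRM loss. One widely used form is the IRMv1 penalty of Arjovsky et al. (2019): for each training environment, the squared gradient of the risk with respect to a dummy scalar multiplier w on the logits, evaluated at w = 1. The sketch below assumes binary labels and scalar logits for simplicity; Bayes-CAL operates on image-text alignment logits and pairs IRM with a gradient orthogonalization loss, neither of which is reproduced here. The function names and the toy environments are hypothetical.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def bce_risk(logits, labels, w: float = 1.0) -> float:
    """Mean binary cross-entropy of sigmoid(w * z) against 0/1 labels."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = sigmoid(w * z)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(logits)

def irmv1_penalty(logits, labels) -> float:
    """Squared derivative of bce_risk w.r.t. the dummy scale w, evaluated at w = 1.
    For BCE, d/dw of the per-sample loss is (sigmoid(w * z) - y) * z."""
    grad = sum((sigmoid(z) - y) * z for z, y in zip(logits, labels)) / len(logits)
    return grad ** 2

def irm_objective(environments, lam: float = 1.0) -> float:
    """Sum over environments of the empirical risk plus lam times the IRMv1 penalty."""
    return sum(bce_risk(z, y) + lam * irmv1_penalty(z, y) for z, y in environments)

# Hypothetical logits and labels from two training environments.
env_a = ([2.0, -1.0, 0.5], [1, 0, 1])
env_b = ([1.5, 0.3, -2.0], [1, 1, 0])
print("penalty A:", irmv1_penalty(*env_a))
print("penalty B:", irmv1_penalty(*env_b))
print("IRM objective:", irm_objective([env_a, env_b], lam=10.0))
```

Using the closed-form gradient keeps the sketch dependency-free; in an autodiff framework the same penalty is usually obtained by differentiating the risk with respect to w directly.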


CATS: Mitigating Correlation Shift for Multivariate Time Series Classification

arXiv.org Machine Learning

Unsupervised Domain Adaptation (UDA) leverages labeled source data to train models for unlabeled target data. Given the prevalence of multivariate time series (MTS) data across various domains, the UDA task for MTS classification has emerged as a critical challenge. However, for MTS data, correlations between variables often vary across domains, whereas most existing UDA works for MTS classification have overlooked this essential characteristic. To bridge this gap, we introduce a novel domain shift, correlation shift, measuring domain differences in multivariate correlation. To mitigate correlation shift, we propose a scalable and parameter-efficient Correlation Adapter for MTS (CATS). Designed as a plug-and-play technique compatible with various Transformer variants, CATS employs temporal convolution to capture local temporal patterns and a graph attention module to model the changing multivariate correlation. The adapter reweights the target correlations to align them with the source correlations with a theoretically guaranteed precision. A correlation alignment loss is further proposed to mitigate correlation shift, bypassing the alignment challenge arising from the non-i.i.d. nature of MTS data. Extensive experiments on four real-world datasets demonstrate that (1) compared with vanilla Transformer-based models, CATS improves average accuracy by over 10% while adding only around 1% more parameters, and (2) all Transformer variants equipped with CATS either reach or surpass state-of-the-art baselines.
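
The abstract does not give the form of the correlation alignment loss. A minimal sketch of the underlying idea, assuming correlation shift is measured as the squared Frobenius distance between the Pearson correlation matrices of the variables in a source and a target series, is shown below. This is a hypothetical stand-in, not the CATS loss: it ignores the temporal convolution and graph attention reweighting and uses a single series per domain.

```python
import math

def correlation_matrix(series):
    """Pearson correlation between variables of one multivariate series.
    `series` is a list of time steps, each a list of per-variable values.
    Assumes no variable is constant (its standard deviation would be zero)."""
    n_steps, n_vars = len(series), len(series[0])
    cols = [[step[v] for step in series] for v in range(n_vars)]
    means = [sum(c) / n_steps for c in cols]
    stds = [math.sqrt(sum((x - m) ** 2 for x in c) / n_steps) for c, m in zip(cols, means)]
    corr = [[0.0] * n_vars for _ in range(n_vars)]
    for i in range(n_vars):
        for j in range(n_vars):
            cov = sum((a - means[i]) * (b - means[j]) for a, b in zip(cols[i], cols[j])) / n_steps
            corr[i][j] = cov / (stds[i] * stds[j])
    return corr

def correlation_shift(source, target):
    """Squared Frobenius distance between source and target correlation matrices."""
    cs, ct = correlation_matrix(source), correlation_matrix(target)
    n = len(cs)
    return sum((cs[i][j] - ct[i][j]) ** 2 for i in range(n) for j in range(n))

# Hypothetical 2-variable series: the variables co-move in the source domain
# but move in opposite directions in the target domain.
source = [[t, 2.0 * t] for t in range(5)]
target = [[t, -2.0 * t] for t in range(5)]
print(correlation_shift(source, target))  # close to 8: both off-diagonal entries flip from +1 to -1
```

Here the per-variable marginal spreads are identical across domains and only the inter-variable correlation differs, which is the kind of shift the abstract argues most UDA methods for MTS overlook.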


Learning Fair Invariant Representations under Covariate and Correlation Shifts Simultaneously

arXiv.org Artificial Intelligence

Generalizing an invariant classifier from training domains to shifted test domains while also accounting for model fairness is a substantial and complex challenge in machine learning. Existing methods for fairness-aware domain generalization focus on either covariate shift or correlation shift, but rarely consider both at the same time. In this paper, we introduce a novel approach for learning a fairness-aware domain-invariant predictor within a framework that addresses covariate and correlation shifts simultaneously, ensuring its generalization to unknown test domains that are inaccessible during training. In our approach, data are first disentangled into content and style factors in latent spaces. Fairness-aware domain-invariant content representations are then learned by mitigating sensitive information while retaining as much other information as possible. Extensive empirical studies on benchmark datasets demonstrate that our approach surpasses state-of-the-art methods with respect to model accuracy as well as both group and individual fairness.
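
The abstract does not name its group-fairness metrics. Demographic parity difference is one common choice: the absolute gap in positive-prediction rates between the two values of a binary sensitive attribute. A minimal sketch, assuming binary predictions and a binary sensitive attribute; the function name and data are hypothetical:

```python
def demographic_parity_difference(predictions, sensitive):
    """|P(y_hat = 1 | s = 0) - P(y_hat = 1 | s = 1)| for 0/1 predictions
    and a binary sensitive attribute s."""
    rates = []
    for group in (0, 1):
        preds = [p for p, s in zip(predictions, sensitive) if s == group]
        if not preds:
            raise ValueError(f"no samples with sensitive attribute {group}")
        rates.append(sum(preds) / len(preds))
    return abs(rates[0] - rates[1])

# Hypothetical predictions on one test domain.
predictions = [1, 1, 0, 1, 0, 0, 0, 1]
sensitive = [0, 0, 0, 0, 1, 1, 1, 1]
print(demographic_parity_difference(predictions, sensitive))  # 0.5
```

A value of 0 means both groups receive positive predictions at the same rate; under domain shift such a metric would typically be computed separately on each unseen test domain, which is an assumption about the evaluation protocol rather than something the abstract states.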